import os
import sys
import pybedtools
import pysam


# Command-line arguments: DATASET selects the mapping directory (and, when it
# is "MiSeq", the paired-record parser); LIBRARY is the BAM file stem.
if len(sys.argv) != 3:
    sys.exit("Usage: %s DATASET LIBRARY" % sys.argv[0])
dataset, library = sys.argv[1:]

# Alignments whose XT tag names one of these targets are discarded outright.
skip_targets = (
    'chrM', 'rRNA', 'tRNA', 'snRNA', 'scRNA', 'snoRNA', 'yRNA',
    'histone', 'scaRNA', 'snar', 'vRNA',
)

# Every XT value not listed in skip_targets must be one of these; the parsers
# assert on anything else.
keep_targets = (
    'RMRP', 'RPPH', 'TERC', 'MALAT1', 'snhg',
    'mRNA', 'lncRNA', 'gencode', 'fantomcat', 'genome',
)

# The only XE values expected on mRNA/lncRNA/gencode alignments; alignments
# carrying an XE tag are dropped.
skip_matures = (
    'spliced_mRNA', 'spliced_lncRNA', 'spliced_gencode',
)

# The only XA values expected; alignments carrying an XA tag are dropped.
skip_annotations = (
    'presnRNA', 'pretRNA', 'prescaRNA', 'presnoRNA',
)


def parse_bamfile(path):
    """Collapse uniquely mapped single-end alignments into 1-bp BED intervals.

    Reads the BAM file at *path* and drops:

    * unmapped reads;
    * reads whose ``XT`` tag is in ``skip_targets`` (every other ``XT`` value
      must be in ``keep_targets``);
    * reads carrying an ``XA`` tag (its value must be in ``skip_annotations``);
    * for ``mRNA``/``lncRNA``/``gencode`` targets, reads carrying an ``XE``
      tag (its value must be in ``skip_matures``);
    * reads whose ``XL`` tag is smaller than the aligned reference length
      (treated as spliced);
    * reads whose ``NH`` tag is not 1.

    Each remaining read is reduced to its 5' end: the leftmost aligned base
    on the ``+`` strand, the last aligned base on the ``-`` strand.
    Consecutive reads sharing the same (chromosome, position, strand) are
    combined into one interval.

    Yields:
        pybedtools Interval with fields: chromosome, start, start + 1,
        comma-separated sorted ``XT`` targets, read count (as a string),
        strand.  Nothing is yielded when no alignment passes the filters.

    Raises:
        AssertionError: if a tag holds a value outside the expected sets.
    """
    def make_interval(key, names, total):
        # key is (chromosome, start, strand); build a 1-bp BED6 interval.
        chromosome, start, strand = key
        fields = [chromosome, start, start + 1,
                  ",".join(sorted(names)), str(total), strand]
        return pybedtools.create_interval_from_list(fields)

    current = None
    targets = set()
    count = 0.0
    print("Reading", path)
    with pysam.AlignmentFile(path) as bamfile:
        for line in bamfile:
            if line.is_unmapped:
                continue
            target = line.get_tag("XT")
            if target in skip_targets:
                continue
            assert target in keep_targets, target
            if line.has_tag("XA"):
                # Any XA-annotated alignment is dropped; only known values
                # are expected here.
                annotation = line.get_tag("XA")
                assert annotation in skip_annotations, annotation
                continue
            if target in ("mRNA", "lncRNA", "gencode") and line.has_tag("XE"):
                annotation = line.get_tag("XE")
                assert annotation in skip_matures, annotation
                continue
            if line.has_tag("XL") and line.get_tag("XL") < line.reference_length:
                continue  # spliced
            multimap = line.get_tag("NH")
            if multimap != 1:
                continue
            if line.is_reverse:
                strand = '-'
                # reference_end is exclusive, so the last aligned base is -1.
                start = line.reference_end - 1
            else:
                strand = '+'
                start = line.reference_start
            key = (line.reference_name, start, strand)
            if key != current:
                if current is not None:
                    yield make_interval(current, targets, count)
                current = key
                targets = set()
                count = 0.0
            targets.add(target)
            count += 1.0 / multimap
    # Flush the last group; skip when nothing passed the filters, which
    # previously crashed on the unbound accumulator.
    if current is not None:
        yield make_interval(current, targets, count)

def parse_miseq_bamfile(path):
    """Collapse uniquely mapped read pairs into 1-bp BED intervals.

    Like ``parse_bamfile``, but the BAM file at *path* is read as consecutive
    record pairs.  All tag-based filtering (``XT``, ``XA``, ``XE``, ``NH``,
    ``XL``) uses the first record of each pair.  The second record is used to
    check consistency (both unmapped, opposite strands, same chromosome) and
    to compute the fragment span that is compared against ``XL``.

    A pair is reduced to the 5' end of its first record: the leftmost aligned
    base on the ``+`` strand, the last aligned base on the ``-`` strand.
    Consecutive pairs sharing the same (chromosome, position, strand) are
    combined into one interval.

    Yields:
        pybedtools Interval with fields: chromosome, start, start + 1,
        comma-separated sorted ``XT`` targets, pair count (as a string),
        strand.  Nothing is yielded when no pair passes the filters.

    Raises:
        ValueError: if the file holds an odd number of records.
        AssertionError: if a tag or mate property is outside expectations.
    """
    def make_interval(key, names, total):
        # key is (chromosome, start, strand); build a 1-bp BED6 interval.
        chromosome, start, strand = key
        fields = [chromosome, start, start + 1,
                  ",".join(sorted(names)), str(total), strand]
        return pybedtools.create_interval_from_list(fields)

    current = None
    targets = set()
    count = 0.0
    print("Reading", path)
    with pysam.AlignmentFile(path) as bamfile:
        records = iter(bamfile)
        for line1 in records:
            # Take the mate explicitly: a bare next() at end of file would leak
            # StopIteration, which Python 3.7+ turns into an opaque
            # RuntimeError inside a generator (PEP 479).
            line2 = next(records, None)
            if line2 is None:
                raise ValueError("%s: odd number of records; %s has no mate"
                                 % (path, line1.query_name))
            if line1.is_unmapped:
                assert line2.is_unmapped
                continue
            target = line1.get_tag("XT")
            if target in skip_targets:
                continue
            assert target in keep_targets, target
            if line1.has_tag("XA"):
                annotation = line1.get_tag("XA")
                assert annotation in skip_annotations, annotation
                continue
            if target in ("mRNA", "lncRNA", "gencode") and line1.has_tag("XE"):
                annotation = line1.get_tag("XE")
                assert annotation in skip_matures, annotation
                continue
            multimap = line1.get_tag("NH")
            if multimap != 1:
                continue
            if line1.is_reverse:
                assert not line2.is_reverse
                strand = '-'
                # reference_end is exclusive, so the last aligned base is -1.
                start = line1.reference_end - 1
                reference_length = line1.reference_end - line2.reference_start
            else:
                assert line2.is_reverse
                strand = '+'
                # The first read's own 5' end, mirroring the '-' branch; the
                # fragment span below also starts here.
                start = line1.reference_start
                reference_length = line2.reference_end - line1.reference_start
            assert reference_length > 0
            if line1.has_tag("XL") and line1.get_tag("XL") < reference_length:
                continue  # spliced
            chromosome = line1.reference_name
            assert chromosome == line2.reference_name
            key = (chromosome, start, strand)
            if key != current:
                if current is not None:
                    yield make_interval(current, targets, count)
                current = key
                targets = set()
                count = 0.0
            targets.add(target)
            count += 1.0 / multimap
    # Flush the last group; skip when nothing passed the filters.
    if current is not None:
        yield make_interval(current, targets, count)


# Input BAM: <CASPARs root>/<dataset>/Mapping/<library>.bam
directory = "/osc-fs_home/mdehoon/Data/CASPARs/%s/Mapping" % dataset
filename = "%s.bam" % library
path = os.path.join(directory, filename)

# The "MiSeq" dataset is parsed as consecutive record pairs; every other
# dataset one record at a time.  The resulting generator is consumed exactly
# once, by saveas(), which writes it to a temporary BED file.
alignments = pybedtools.BedTool(
    parse_miseq_bamfile(path) if dataset == "MiSeq" else parse_bamfile(path)
).saveas()

print("Sorting")
alignments = alignments.sort()

# Strand-aware merge: names and strands are collapsed to distinct values,
# scores are summed.  With d=-1 bedtools requires features to overlap, so for
# these 1-bp intervals presumably only identical positions combine.
print("Merging")
alignments = alignments.merge(
    s=True,
    d=-1,
    c=(4, 5, 6),
    o=('distinct', 'sum', 'distinct'),
)

filename = "%s.%s.ctss.bed" % (dataset, library)
print("Writing", filename)
alignments.saveas(filename)
